# This module defines the environment class of the irrigation scheduler ---- by Bernard Twum Agyeman (agyeman@ualberta.ca)
# In this environment, the reference evapotranspiration and the crop coefficient are treated as states
from tensorforce.environments import Environment 
from Parameters_1D import *
import numpy as np
from Model_Simulation import SimulateModel
# Spatial discretization returned by spatialVariables_1D_act (presumably provided by the
# wildcard import from Parameters_1D -- verify). Within this file only axialNodes is used:
# it sizes the nodal part of the environment's state vector.
totalDepth, axialNodes, totalNodes = spatialVariables_1D_act() 

class IrrigationScheduler(Environment):
    """Tensorforce environment for the irrigation scheduling problem.

    State vector layout (length ``axialNodes + 4``):
        ``[0:axialNodes]``    nodal soil states handed to ``SimulateModel``
                              (treated as volumetric moisture by
                              ``compute_rootZone_moisture``)
        ``[axialNodes]``      reference evapotranspiration
        ``[axialNodes + 1]``  crop coefficient
        ``[axialNodes + 2]``  rooting depth (also addressed as ``[-2]``)
        ``[axialNodes + 3]``  LAI factor (also addressed as ``[-1]``)

    Actions are a dict with a binary ``"irrigationDecision"`` and a bounded
    float ``"irrigationRate"`` of shape ``(1,)``.
    """

    # The rooting depth alternates between 0.50 and 1.00 every
    # _ROOTING_BLOCK_EPISODES episodes, for the first
    # _ROOTING_SCHEDULE_EPISODES episodes only.
    _ROOTING_BLOCK_EPISODES = 25000
    _ROOTING_SCHEDULE_EPISODES = 500000

    # (weight, first node index, last node index) -- both indices inclusive.
    # Neighbouring layers deliberately share their boundary node.
    _ROOT_ZONE_LAYERS_SHALLOW = (   # used when the rooting depth is 0.50
        (0.10, 10, 15),
        (0.20, 15, 20),
        (0.30, 20, 25),
        (0.40, 25, 30),
    )
    _ROOT_ZONE_LAYERS_DEEP = (      # used for any other rooting depth
        (0.10, 0, 5),
        (0.20, 5, 10),
        (0.30, 10, 20),
        (0.40, 20, 30),
    )

    # Slack used when validating the irrigation rate, so that a bound value
    # that went through a float32 round trip is not rejected.
    _RATE_TOLERANCE = 1e-6

    def __init__(self):
        """Set bounds, reward weights and a random initial state."""
        # Target zone for the root-zone moisture (see reward_compute).
        self.lowerBound_zone = 0.209 #0.30-0.65*(0.30-0.16)
        self.upperBound_zone = 0.300
        # Admissible range of every state component.
        self.maximum_state_loam = 0.39
        self.minimum_state_loam = 0.10
        self.refEvap_max = 9.0
        self.refEvap_min = 1.04
        self.cropCoeff_max = 1.25
        self.cropCoeff_min = 0.20
        self.rootingDepth_max = 1.00
        self.rootingDepth_min = 0.25
        self.LAI_max = 1.00
        self.LAI_min = 0.00
        # Action bounds; both are negative, so irrig_min is the larger magnitude.
        self.irrig_max = -0.0050112 #m
        self.irrig_min = -0.062208 #m
        # Reward weights (weight_bin / weight_irrig are re-selected by
        # set_weighting_parameters on every reward evaluation).
        self.weight_bin = 1000
        self.weight_irrig = -9000
        self.weight_LZ = 2e07
        self.weight_UZ = 2.2e07

        # state_LB / state_UB: bounds advertised to the agent via states().
        self.state_UB = self._assemble_bound(self.maximum_state_loam, upper=True)
        self.state_LB = self._assemble_bound(self.minimum_state_loam, upper=False)
        # lowerBound_state / upperBound_state: narrower nodal range used only
        # for sampling the initial state.
        self.upperBound_state = self._assemble_bound(self.maximum_state_loam - 0.04, upper=True)
        self.lowerBound_state = self._assemble_bound(self.minimum_state_loam + 0.03, upper=False)

        self.current_state = np.random.uniform(self.lowerBound_state, self.upperBound_state)
        # Defined here as well so that execute() works even before reset().
        self.timestep = 0
        super().__init__()

    def _assemble_bound(self, nodal_value, upper):
        """Build one bound vector: nodal_value for every node, then the four extra states."""
        if upper:
            extra_states = [self.refEvap_max, self.cropCoeff_max,
                            self.rootingDepth_max, self.LAI_max]
        else:
            extra_states = [self.refEvap_min, self.cropCoeff_min,
                            self.rootingDepth_min, self.LAI_min]
        return np.concatenate((np.full(axialNodes, nodal_value), extra_states))

    def states(self):
        """Return the state specification (float vector with per-component bounds)."""
        return dict(type='float', shape=(axialNodes+4,), min_value=self.state_LB, max_value=self.state_UB)

    def actions(self):
        """Return the action specification: a binary decision and a bounded rate."""
        return {"irrigationDecision": dict(type='int', num_values=2),
                "irrigationRate": dict(type='float', shape=(1,), min_value=self.irrig_min, max_value=self.irrig_max)}

    def set_LAI_factor(self, lai_factor):
        """Overwrite the LAI factor (last state component)."""
        self.current_state[-1] = lai_factor

    def set_weighting_parameters(self):
        """Select the decision / rate reward weights from the current rooting depth."""
        # Exact comparison: it matches the value written by set_rooting_depth.
        if self.current_state[-2] == 0.50:
            self.weight_bin = 1000
            self.weight_irrig = -9000
        else:
            self.weight_bin = 9000      #Can be tuned
            self.weight_irrig = -1000   #Can be tuned

    def set_rooting_depth(self, episodeNumber):
        """Set the rooting depth state from the episode number.

        Episodes are grouped in blocks of 25000: even blocks (0-24999,
        50000-74999, ...) use 0.50 and odd blocks use 1.00. Episode numbers
        outside ``[0, 500000)`` or with a fractional part leave the state
        untouched.
        """
        if not 0 <= episodeNumber < self._ROOTING_SCHEDULE_EPISODES:
            return
        if episodeNumber != int(episodeNumber):
            return
        block = int(episodeNumber) // self._ROOTING_BLOCK_EPISODES
        self.current_state[-2] = 0.50 if block % 2 == 0 else 1.00
        return

    def compute_rootZone_moisture(self):
        """Return the weighted root-zone average of the nodal states.

        Each layer contributes ``weight * mean(nodes of the layer)``; the set
        of layers depends on whether the rooting depth state equals 0.50.

        Raises:
            IndexError: if the state vector is too short for a layer.
        """
        if self.current_state[-2] == 0.50:
            layers = self._ROOT_ZONE_LAYERS_SHALLOW
        else:
            layers = self._ROOT_ZONE_LAYERS_DEEP

        rootZone_vol_moist = 0.0
        for weight, first, last in layers:
            count = last - first + 1
            nodes = self.current_state[first:last + 1]
            if len(nodes) != count:
                # A slice would silently shrink; keep failing loudly instead.
                raise IndexError("state vector too short for root-zone nodes %d-%d" % (first, last))
            # Built-in sum adds left to right, which keeps the result
            # bit-identical to a hand-written x[first] + ... + x[last].
            rootZone_vol_moist += weight * ((1 / count) * sum(nodes))
        return rootZone_vol_moist

    def max_episode_timesteps(self):
        """Defer to the base class for the episode length limit."""
        return super().max_episode_timesteps()

    def close(self):
        """Defer to the base class for clean-up."""
        super().close()

    def reset(self):
        """Start a new episode: zero the step counter and sample a fresh random state."""
        self.timestep = 0
        self.current_state = np.random.uniform(self.lowerBound_state, self.upperBound_state)
        return self.current_state

    def response(self, actions):
        """Advance the model one step and return whatever SimulateModel returns.

        The applied irrigation is ``irrigationDecision * irrigationRate``, i.e.
        zero when the decision is 0.
        NOTE(review): the result replaces ``current_state`` in execute(), so
        SimulateModel is assumed to return the full ``axialNodes + 4`` state
        vector -- confirm against Model_Simulation.
        """
        return SimulateModel(state=self.current_state[0:axialNodes], action=actions["irrigationDecision"]*actions["irrigationRate"],
                        refEvap = self.current_state[axialNodes], cropCoeff = self.current_state[axialNodes+1], rooting_depth=self.current_state[axialNodes+2], lai_factor=self.current_state[-1])

    def reward_compute(self, actions):
        """Return the scalar reward for the current state and the given actions.

        The reward is the negated sum of a quadratic penalty for leaving the
        target moisture zone, a cost for deciding to irrigate and a cost
        proportional to the irrigation rate.
        """
        # Root-zone soil moisture for the current rooting depth.
        rootZone_vol_moist = self.compute_rootZone_moisture()
        # Choose the weighting parameters that belong to the rooting depth.
        self.set_weighting_parameters()
        # Zone tracking.
        if self.lowerBound_zone <= rootZone_vol_moist <= self.upperBound_zone:
            # The root-zone moisture lies inside the zone.
            reward_zone = 0
        elif rootZone_vol_moist < self.lowerBound_zone:
            # The lower bound of the zone is violated.
            reward_zone = self.weight_LZ*(np.abs(rootZone_vol_moist - self.lowerBound_zone))**2
        else:
            # The upper bound of the zone is violated.
            reward_zone = self.weight_UZ*(np.abs(rootZone_vol_moist - self.upperBound_zone))**2
        reward_bin = self.weight_bin*actions["irrigationDecision"]
        # NOTE(review): the rate cost is charged even when irrigationDecision
        # is 0 -- confirm this is intended.
        reward_irrig = self.weight_irrig*actions["irrigationRate"]
        reward_total = -1*(reward_zone + reward_bin + reward_irrig)
        # irrigationRate has shape (1,), so the total is a one-element array.
        return reward_total[0]

    def execute(self, actions):
        """Apply the actions for one time step.

        Returns:
            tuple: ``(next_state, terminal, reward)``; ``terminal`` is always
            False here.

        Raises:
            AssertionError: if the decision is not 0/1 or the rate lies
                outside ``[irrig_min, irrig_max]``.
        """
        decision = actions["irrigationDecision"]
        rate = actions["irrigationRate"]
        assert decision == 0 or decision == 1, "irrigationDecision must be 0 or 1"
        # Both bounds must hold; with `or` this check could never fail.
        assert (np.all(rate <= self.irrig_max + self._RATE_TOLERANCE)
                and np.all(rate >= self.irrig_min - self._RATE_TOLERANCE)), \
            "irrigationRate outside [irrig_min, irrig_max]"
        self.timestep += 1
        self.current_state = self.response(actions)
        reward = self.reward_compute(actions)
        terminal = False
        return self.current_state, terminal, reward